package spark.pipeline.KDD99

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Trains and evaluates a decision-tree classifier on the KDD Cup 1999
 * network-intrusion dataset (10% subset).
 *
 * Pipeline: read CSV -> name the 42 columns -> StringIndex the categorical
 * columns (protocol_type, service, flag, label) -> assemble 41 numeric
 * features into a vector -> index label/features -> fit a DecisionTree ->
 * evaluate multiclass accuracy on a held-out test split.
 *
 * Usage: the first program argument, if present, overrides the default
 * input-file path.
 */
object Kdd99Tree {
    def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("DecisionTree").setMaster("local[16]")
        val sc = new SparkContext(conf)
        sc.setLogLevel("ERROR")

        // getOrCreate() reuses the SparkContext created above.
        val spark = SparkSession.builder().appName("Kdd99").config("example", "some-value").getOrCreate()

        // Input path: allow override via args(0); default kept for backward compatibility.
        val file =
            if (args.nonEmpty) args(0)
            else "C:/Users/Lenovo/Desktop/Working/Python/data/kddcup.data_10_percent_corrected.csv"

        // Read the raw CSV (no header row in the KDD'99 file) and attach the
        // 41 feature names plus the trailing "label" column.
        val data = spark.read.csv(file)

        val df = data
          .toDF("duration", "protocol_type", "service", "flag", "src_bytes", "dst_bytes", "land",
              "wrong_fragment", "urgent", "hot", "num_failed_logins", "logged_in", "num_compromised", "root_shell",
              "su_attempted", "num_root", "num_file_creations", "num_shells", "num_access_files", "num_outbound_cmds",
              "is_host_login", "is_guest_login", "count", "srv_count", "serror_rate", "srv_serror_rate", "rerror_rate",
              "srv_rerror_rate", "same_srv_rate", "diff_srv_rate", "srv_diff_host_rate", "dst_host_count",
              "dst_host_srv_count", "dst_host_same_srv_rate", "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
              "dst_host_srv_diff_host_rate", "dst_host_serror_rate", "dst_host_srv_serror_rate", "dst_host_rerror_rate",
              "dst_host_srv_rerror_rate", "label")

        // Convert the "protocol_type" categorical column to a numeric index.
        // (Bug fix: original code referenced an undefined `dataFrame`; the
        // DataFrame built above is `df`.)
        val pretocol_indexer = new StringIndexer().setInputCol("protocol_type").setOutputCol("protocol_typeIndex").fit(df)
        val indexed_0 = pretocol_indexer.transform(df)

        // Convert the "service" categorical column to a numeric index.
        val service_indexer = new StringIndexer().setInputCol("service").setOutputCol("serviceIndex").fit(indexed_0)
        val indexed_1 = service_indexer.transform(indexed_0)

        // Convert the "flag" categorical column to a numeric index.
        val flag_indexer = new StringIndexer().setInputCol("flag").setOutputCol("flagIndex").fit(indexed_1)
        val indexed_2 = flag_indexer.transform(indexed_1)

        // Convert the "label" (attack type) column to a numeric index.
        val label_indexer = new StringIndexer().setInputCol("label").setOutputCol("labelIndex").fit(indexed_2)
        val indexed_df = label_indexer.transform(indexed_2)

        // Drop the original string-typed categorical columns now that their
        // indexed replacements exist.
        val df_final = indexed_df.drop("protocol_type").drop("service")
          .drop("flag").drop("label")

        // Assemble the 41 numeric columns (38 raw + 3 indexed categoricals)
        // into a single "features" vector column.
        val assembler = new VectorAssembler().setInputCols(Array("duration", "src_bytes", "dst_bytes", "land",
            "wrong_fragment", "urgent", "hot", "num_failed_logins", "logged_in", "num_compromised", "root_shell",
            "su_attempted", "num_root", "num_file_creations", "num_shells", "num_access_files", "num_outbound_cmds",
            "is_host_login", "is_guest_login", "count", "srv_count", "serror_rate", "srv_serror_rate", "rerror_rate",
            "srv_rerror_rate", "same_srv_rate", "diff_srv_rate", "srv_diff_host_rate", "dst_host_count",
            "dst_host_srv_count", "dst_host_same_srv_rate", "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
            "dst_host_srv_diff_host_rate", "dst_host_serror_rate", "dst_host_srv_serror_rate", "dst_host_rerror_rate",
            "dst_host_srv_rerror_rate", "protocol_typeIndex", "serviceIndex", "flagIndex"))
          .setOutputCol("features")

        // CSV columns are read as strings; cast everything to Double before assembling.
        val cols = df_final.columns.map(f => col(f).cast(DoubleType))
        val df_finalDataFrame = assembler.transform(df_final.select(cols: _*))

        // Keep only (label, features); cache because the DataFrame is reused
        // by three fits (two indexers + pipeline) and the train/test split.
        val featuresDataFrame = df_finalDataFrame.select("labelIndex", "features")
          .withColumnRenamed("labelIndex", "label")
        featuresDataFrame.cache()
        featuresDataFrame.show(5, false)

        val labelIndexer = new StringIndexer().setInputCol("label")
          .setOutputCol("indexedLabel").fit(featuresDataFrame)

        // Features with <= 10 distinct values are treated as categorical.
        val featureIndexer = new VectorIndexer().setInputCol("features")
          .setOutputCol("indexedFeatures").setMaxCategories(10).fit(featuresDataFrame)

        // Map numeric predictions back to the original label strings.
        val labelConverter = new IndexToString().setInputCol("prediction")
          .setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

        // 70/30 train/test split.
        val Array(trainData, testData) = featuresDataFrame.randomSplit(Array(0.7, 0.3))

        // Decision-tree classifier.
        val dtClassifier = new DecisionTreeClassifier()
          .setLabelCol("indexedLabel")
          .setFeaturesCol("indexedFeatures")
          .setMaxBins(100).setMaxDepth(8).setMinInfoGain(0.01)

        val lrPipeline = new Pipeline()
          .setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))

        val dtPipelineModel = lrPipeline.fit(trainData)
        val dtPredictions = dtPipelineModel.transform(testData)
        dtPredictions.show(10, false)

        // Bug fix: without setMetricName the evaluator defaults to "f1",
        // but the message below reports accuracy — request accuracy explicitly.
        val evaluator = new MulticlassClassificationEvaluator()
          .setPredictionCol("prediction")
          .setLabelCol("indexedLabel")
          .setMetricName("accuracy")

        val dtAccuracy = evaluator.evaluate(dtPredictions)
        // Typo fix: 测试机 (test machine) -> 测试集 (test set).
        println("测试集的正确率为:" + dtAccuracy)

        // Stage 2 of the pipeline is the fitted decision tree; dump its structure.
        val treeModelClassifier = dtPipelineModel.stages(2)
          .asInstanceOf[DecisionTreeClassificationModel]
        println("Learned classification tree model: \n" + treeModelClassifier.toDebugString)

        // Release cached data and shut down Spark cleanly.
        featuresDataFrame.unpersist()
        spark.stop()
    }
}
